import numpy as np
import time
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from torch.nn import Parameter

class HCPN(nn.Module):
    def __init__(self, data_dim, dim_proto, dim_cls, num_class, n_atom_alloc, num_emb_a, num_emb_r, n_emb_a_select, n_emb_r_select, atom_t, w_dr_intra, w_dr_inter, w_obj_shr, dr_dis, num_atom_r, devices):
        super(HCPN, self).__init__()
        self.a_o_att_w_GAT = nn.Linear(dim_proto, dim_proto)
        self.a_o_att_a_GAT = nn.Linear(dim_proto * 2, 1)
        self.d_prot_a = dim_proto
        self.d_prot_c = dim_cls
        self.emb_attr = [Parameter(torch.empty(data_dim, dim_proto, device=devices).uniform_(-np.sqrt(1./data_dim), np.sqrt(1./data_dim))) for i in range(num_emb_a)]
        self.emb_rela = [Parameter(torch.empty(data_dim, dim_proto, device=devices).uniform_(-np.sqrt(1. / data_dim), np.sqrt(1. / data_dim)))]
        self.w_atom_r = Parameter(torch.tensor(0.0, device=devices), requires_grad=False) # weight for regulation contribution ratio within pair features
        #self.mask_da = utils.block_diag(num_emb_a*dim_proto, dim_proto) # denoting distance within each embedding matrix
        #self.atom_shrink = nn.Linear(dim_proto*(num_emb_a+num_emb_r), dim_proto)
        self.num_emb_a = n_emb_a_select # how many to select from all embs
        self.num_emb_r = n_emb_r_select
        self.c_emb_attr_id = 0 # from which emb matrix is being used currently
        self.c_emb_rela_id = 0
        self.c_emb_attr_id_end = None
        self.c_emb_rela_id_end = None
        self.emb_attr_id_rec = [0,num_emb_a]
        self.emb_rela_id_rec = [0,1]
        self.div_reg_t = 0.9 # threshold for forcing emb matrices to be orthogonal
        self.prototypes = utils.Component_prototypes(dim_proto, dim_cls, n_atom_alloc, n_emb_a_select, n_emb_r_select)
        self.classifier_simp_atom = nn.Linear(dim_proto*(n_emb_a_select+n_emb_r_select), num_class) # a simplified classifier
        self.classifier_simp_emb = nn.Linear(dim_proto, num_class)  # a simplified classifier
        self.classifier_simp_obj = nn.Linear(dim_proto, num_class)
        #self.classifier_simp_obj = nn.Linear(dim_proto*(num_emb_a+num_emb_r), num_class)
        self.classifier_simp_ao = nn.Linear(dim_proto*(n_emb_a_select+n_emb_r_select+1), num_class)
        self.classifier_simp_aoc = nn.Linear(dim_proto*(num_emb_a+num_emb_r+1)+self.d_prot_c, num_class)
        #self.classifier_simp_ao = nn.Linear(dim_proto * 2, num_class)
        self.classifier_atten_GAT = utils.atten_classifier_GAT_o(dim_proto, dim_proto, num_class, inte='concat')
        #self.classifier_lstm = utils.classifier_lstm(num_class, dim_proto*(n_emb_a_select+n_emb_r_select), dim_proto)
        self.num_class = num_class
        self.criterion = nn.CrossEntropyLoss() #nn.BCEWithLogitsLoss()
        self.emb_pro_dis_loss = nn.MSELoss()
        self.atom_t = atom_t
        self.w_dr_intra = w_dr_intra # scaling factor of penalty on atom diversity
        self.w_dr_inter = w_dr_inter
        self.w_obj_shr = w_obj_shr # scaling factor of object shrink loss
        self.dr_dis = dr_dis # min distance used in diversity penalty
        self.sigmoid = nn.Sigmoid()
        self.batch_norm = nn.BatchNorm1d(dim_proto)
        self.relu = nn.ReLU()
        self.device = devices
        self.data_dim = data_dim
        self.ids_record = []
        self.proto_ids_record = []
        self.num_atom_r = num_atom_r
        self.task_id_c = 0

    def forward(self, data, c_ids, est_proto, est_obj=False, proto_cls=False, task_id = None):
        # data prepare
        train_ids, valida_ids, test_ids, graph, multi_nbs, features, y_train, y_val, y_test, labels = data
        c_labels = labels[c_ids]
        nb_ids = [multi_nbs[id] for id in c_ids]
        nei_ids_sampled = utils.lil_sample(nb_ids, self.num_atom_r, flatten=True)
        nei_ids_sampled = np.array(nei_ids_sampled).reshape(-1)

        c_feats = torch.tensor(features[c_ids], dtype=torch.float, device=self.device)  # the features of current nodes [batch_size, dim_feats]
        batch_size = c_feats.shape[0]
        c_feats = c_feats.view(batch_size, 1, -1)

        nei_feats = torch.tensor(features[nei_ids_sampled], device=self.device, dtype = torch.float).view(batch_size, sum(self.num_atom_r), -1)

        pair_feats = self.w_atom_r*c_feats.repeat([1, self.num_emb_r, 1]) + (1.0-self.w_atom_r)*nei_feats
        #self.mask_da = self.mask_da.cuda(c_feats.get_device())

        ## embedding
        '''
        emb_attr = torch.cat(self.emb_attr[self.c_emb_attr_id:], dim=1) if self.c_emb_attr_id_end is None else torch.cat(self.emb_attr[self.c_emb_attr_id:self.c_emb_attr_id_end], dim=1)
        emb_rela = torch.cat(self.emb_rela[self.c_emb_rela_id:], dim=1) if self.c_emb_rela_id_end is None else torch.cat(self.emb_rela[self.c_emb_rela_id:self.c_emb_rela_id_end], dim=1)
        attr_embs = (c_feats.view(batch_size, -1).mm(emb_attr)).view(batch_size, -1, self.d_prot_a) # embed current vertex's features [batch_size, la, dim_proto]
        rela_embs = (pair_feats.view(batch_size*self.num_emb_r, -1)).mm(emb_rela).view(batch_size, -1, self.d_prot_a)  # relational atomic embeddings, of shape [batch_size, l_r, dim_proto]
        '''
        emb_attr_selected = self.emb_attr[self.c_emb_attr_id:] if self.c_emb_attr_id_end is None else self.emb_attr[self.c_emb_attr_id:self.c_emb_attr_id_end]
        emb_rela_selected = self.emb_rela[self.c_emb_rela_id:] if self.c_emb_rela_id_end is None else self.emb_attr[self.c_emb_rela_id:self.c_emb_rela_id_end]
        emb_attr = [(c_feats.view(batch_size, -1).mm(emb)).view(batch_size, -1, self.d_prot_a) for emb in emb_attr_selected] # la * [batch, 1, d_proto]
        emb_rela = [(pair_feats.view(batch_size*self.num_emb_r, -1).mm(emb)).view(batch_size, -1, self.d_prot_a) for emb in emb_rela_selected] # n_emb_r * [batch, lr, d_proto]
        embs = [torch.cat((emb_attr[i],emb_rela[i]), dim=1) for i in range(len(emb_attr_selected))] # num_emb_r/la * [batch, lr+1, d_proto]

        '''
        emb_attr = torch.cat(self.emb_attr, dim=1) if self.c_emb_attr_id_end is None else torch.cat(self.emb_attr[self.c_emb_attr_id:self.c_emb_attr_id_end], dim=1)
        emb_rela = torch.cat(self.emb_rela, dim=1) if self.c_emb_rela_id_end is None else torch.cat(self.emb_rela[self.c_emb_rela_id:self.c_emb_rela_id_end], dim=1)
        dim_start = self.c_emb_attr_id * self.d_prot_a
        attr_embs = ((c_feats.view(batch_size, -1).mm(emb_attr)))[self.c_emb_attr_id:] if self.c_emb_attr_id_end is None else (c_feats.view(batch_size, -1).mm(emb_attr))[self.c_emb_attr_id:self.c_emb_attr_id_end]
        attr_embs = attr_embs.view(batch_size, -1, self.d_prot_a) # embed current vertex's features [batch_size, la, dim_proto]
        rela_embs = (pair_feats.view(batch_size*self.num_emb_r, -1)).mm(emb_rela)[self.c_emb_rela_id:]if self.c_emb_attr_id_end is None else (pair_feats.view(batch_size*self.num_emb_r, -1)).mm(emb_rela)[self.c_emb_attr_id:self.c_emb_attr_id_end]
        rela_embs = rela_embs.view(batch_size, -1, self.d_prot_a)  # relational atomic embeddings, of shape [batch_size, l_r, dim_proto]
        #print('starting of attr emb and rela emb are {} and {}'.format(self.c_emb_attr_id, self.c_emb_rela_id))
        '''

        atom_embs = torch.cat(embs, dim=1)  # atomical embeddings of relations of a vertex to other vertices  [batch, n_emb *(lr+1), dim_proto]
        atom_embs_n = F.normalize(atom_embs, p=2, dim=-1) # normalize each component embedding into a unit ball

        ## prototype interaction
        self.prototypes = self.prototypes.cuda(atom_embs_n.get_device())
        if task_id==None:
            task_id=self.prototypes.AFE_select(c_ids, atom_embs_n, self.atom_t, est_proto, task_id=task_id)
        associated_atoms, associated_obj_embs, associated_cls, hard_corres_atom, hard_corres_obj, selected_sorted_ids = self.prototypes.update(c_ids, atom_embs_n, self.atom_t, est_proto, task_id=task_id) # [batch_size * (la+lr), num_protos] # correspondence.mm(self.prototypes.atoms[0:n_atoms]) # [batch_size * (la+lr), dim_proto]
        id_batch = torch.tensor(range(batch_size)).view(batch_size, 1, 1)
        id_dim = torch.tensor(range(self.d_prot_a)).view(1, 1, self.d_prot_a)
        self.atom_embs = atom_embs_n[id_batch, selected_sorted_ids, id_dim]
        # record mapping between embeddings and protos
        if est_proto:
            cls_proto_map_c = []
            for i in range(self.num_class):
                cls_proto_map_c.append([])
            proto_ids = torch.argmax(hard_corres_atom, dim=1)  # record which protos are selected
            train_ids_record = np.array(c_ids).reshape(-1, 1).repeat(self.num_emb_a, 1)
            nb_ids_record = nei_ids_sampled.reshape(batch_size, self.num_emb_r)
            ids_record = np.concatenate([train_ids_record, nb_ids_record], 1)
            ids_record = ids_record.reshape(-1)
            for i,j in enumerate(ids_record):
                l = np.argmax(labels[j])
                cls_proto_map_c[l].append(proto_ids[i])
            self.ids_record.append(ids_record)
            self.proto_ids_record.append(cls_proto_map_c)

        ## classifier
        c_labels = torch.tensor([np.argmax(label) for label in c_labels], dtype=torch.long, device=atom_embs_n.get_device())

        loss_emb_ato_dis = self.emb_pro_dis_loss(associated_atoms, self.atom_embs.view(-1, self.d_prot_a))
        # loss_emb_pro_dis = utils.cos_dis_loss(preds_emb, preds_proto)

        # atomic embedding classification
        preds_emb = self.classifier_simp_atom(self.atom_embs.view(batch_size, self.d_prot_a * (self.num_emb_r + self.num_emb_a)))
        preds_emb = F.softmax(preds_emb, dim=1)
        loss_cls_emb = self.criterion(preds_emb, c_labels)
        # atom proto classification
        preds_atom = self.classifier_simp_atom(associated_atoms.view(batch_size, self.d_prot_a * (self.num_emb_r + self.num_emb_a)))
        preds_atom = F.softmax(preds_atom, dim=1)
        loss_cls_atom = self.criterion(preds_atom, c_labels)
        '''
        # obj embedding classification
        preds_obj_emb = self.classifier_simp_obj(associated_obj_embs)
        preds_obj_emb = F.softmax(preds_obj_emb, dim=1)
        loss_cls_obj_emb = self.criterion(preds_obj_emb, c_labels)
        '''
        '''
        # obj&atom co-classification (concat after shrinking atom sizes)
        associated_atoms = self.atom_shrink(associated_atoms.view(batch_size, (self.num_emb_r+self.num_emb_a)*self.d_prot_a)) # [batch_size, dim_proto]
        associated_protos = torch.cat([associated_atoms, associated_obj_embs], dim=1)
        preds_ao = self.classifier_simp_ao(associated_protos)
        preds_obj_emb = F.softmax(preds_ao, dim=1)
        loss_cls_obj_emb = self.criterion(preds_ao, c_labels)
        '''

        # obj&atom co-classification (concat)
        associated_cls = associated_cls.view(batch_size, self.d_prot_c)
        associated_obj = associated_obj_embs.view(batch_size, self.d_prot_a)
        associated_atoms = associated_atoms.view(batch_size, (self.num_emb_r+self.num_emb_a)*self.d_prot_a)
        associated_aos = torch.cat([associated_atoms, associated_obj], dim=1)
        associated_aocs = torch.cat([associated_atoms, associated_obj, associated_cls], dim=1)
        preds_aoc = self.classifier_simp_aoc(associated_aocs.view(batch_size, self.d_prot_a * (1 + self.num_emb_r + self.num_emb_a)+self.d_prot_c))
        preds_ao = self.classifier_simp_ao(associated_aos.view(batch_size, self.d_prot_a * (1 + self.num_emb_r + self.num_emb_a)))
        preds_aoc = F.softmax(preds_aoc, dim=1)
        preds_ao = F.softmax(preds_ao, dim=1)
        loss_cls_ao = self.criterion(preds_ao, c_labels)
        loss_cls_aoc = self.criterion(preds_aoc, c_labels)
        '''
        # obj&atom co-classification (lstm)
        associated_obj_embs = associated_obj_embs.view(batch_size, 1, self.d_prot_a)
        associated_atoms = associated_atoms.view(batch_size, self.num_emb_r + self.num_emb_a, self.d_prot_a)
        associated_protos = torch.cat([associated_atoms, associated_obj_embs], dim=1)
        preds_proto = self.classifier_lstm(associated_atoms, associated_obj_embs)
        preds_obj_emb = F.softmax(preds_proto, dim=1)
        loss_cls_obj_emb = self.criterion(preds_proto, c_labels)
        '''
        '''
        # obj&atom co-classification via attention
        preds_obj_emb = self.classifier_atten_GAT(associated_atoms.view(batch_size, self.num_emb_r+self.num_emb_a, self.d_prot_a), associated_obj_embs)
        loss_cls_obj_emb = self.criterion(preds_obj_emb, c_labels)
        '''

        if self.training and not est_proto:
            loss_cls_atom = torch.tensor(0., device = self.device)
            loss_cls_aoc = torch.tensor(0., device = self.device)
            loss_cls_ao = torch.tensor(0., device = self.device)
            loss_emb_ato_dis = loss_emb_ato_dis*0
        elif self.training and est_proto:
            loss_cls_atom = torch.tensor(0., device = self.device)
            loss_cls_emb = torch.tensor(0., device = self.device)
            loss_emb_ato_dis = loss_emb_ato_dis*10

        # loss computation
        diver_reg_attr = torch.tensor([], device=self.device)
        diver_reg_rela = torch.tensor([], device=self.device)
        l_rec_attr = len(self.emb_attr_id_rec)
        l_rec_rela = len(self.emb_rela_id_rec)
        for i in range(l_rec_attr-2):
            m1 = F.normalize(torch.cat(self.emb_attr[self.emb_attr_id_rec[i]:self.emb_attr_id_rec[i+1]], dim=1), p=2, dim=0)
            for j in range(i+1, l_rec_attr-1, 1):
                m2 = F.normalize(torch.cat(self.emb_attr[self.emb_attr_id_rec[j]:self.emb_attr_id_rec[j+1]], dim=1), p=2, dim=0)
                cos_dis = m1.transpose(1,0).mm(m2)
                mask = (cos_dis>self.div_reg_t).float()
                cos_dis_triu = torch.triu(cos_dis * mask, diagonal=-1)
                diver_reg_attr = torch.cat((diver_reg_attr, cos_dis_triu))
        for i in range(l_rec_rela-2):
            m1 = F.normalize(torch.cat(self.emb_rela[self.emb_rela_id_rec[i]:self.emb_rela_id_rec[i+1]], dim=1), p=2, dim=0)
            for j in range(i+1, l_rec_rela-1, 1):
                m2 = F.normalize(torch.cat(self.emb_rela[self.emb_rela_id_rec[j]:self.emb_rela_id_rec[j+1]], dim=1), p=2, dim=0)
                cos_dis = m1.transpose(1,0).mm(m2)
                mask = (cos_dis > self.div_reg_t).float()
                diver_reg_rela = torch.cat((diver_reg_rela, torch.triu(cos_dis*mask, diagonal=-1)))

        diver_reg = 0

        for i in [diver_reg_attr.mean(),diver_reg_rela.mean()]:
            if torch.isnan(i)==0:
                diver_reg = diver_reg + i
        #diver_reg = 0
        '''
        diver_reg = self.w_dr_inter*utils.Pairwise_dis_loss(self.emb_attr.weight, self.dr_dis, mask_d = 1-self.mask_da) \
                    + self.w_dr_intra*(utils.Pairwise_dis_loss(self.emb_attr.weight, self.dr_dis, mask_d = self.mask_da)
                                       + utils.Pairwise_dis_loss(self.emb_rela.weight, self.dr_dis))
        '''
        #loss_emb_ato_dis = 0
        return [loss_cls_emb, loss_cls_atom, loss_cls_aoc+loss_cls_ao, loss_emb_ato_dis, diver_reg], preds_emb, preds_atom, preds_ao, preds_aoc, self.atom_embs, associated_aocs.view(batch_size, self.d_prot_a * (1 + self.num_emb_r + self.num_emb_a)+self.d_prot_c)
        #return [loss_cls_emb, loss_cls_atom, loss_cls_obj_emb, loss_emb_ato_dis], preds_atom, preds_emb, preds_obj_emb, self.atom_embs

    def incre_emb(self, tp, n):
        # increase the number of embedding matrices of the given type
        if tp == 'attr':
            for i in range(n):
                self.emb_attr.append(Parameter(torch.empty(self.data_dim, self.d_prot_a, device=self.device).uniform_(-np.sqrt(1./self.data_dim), np.sqrt(1./self.data_dim))))
            self.emb_attr_id_rec.append(self.emb_attr_id_rec[-1]+n)
            self.c_emb_attr_id = self.emb_attr_id_rec[-2]
            self.prototypes.atom_a_splits.append(self.prototypes.num_atoms)
            print('self.emb attr rec is increased to', self.emb_attr_id_rec)
        elif tp == 'rela':
            for i in range(n):
                self.emb_rela.append(Parameter(torch.empty(self.data_dim, self.d_prot_a, device=self.device).uniform_(-np.sqrt(1./self.data_dim), np.sqrt(1./self.data_dim))))
            self.emb_rela_id_rec.append(self.emb_rela_id_rec[-1] + n)
            self.c_emb_rela_id = self.emb_rela_id_rec[-2]
            self.prototypes.atom_r_splits.append(self.prototypes.num_atoms)
            print('self.emb rela rec is increased to',self.emb_rela_id_rec)
